{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "NestedTensor as unifying datastructure for non-uniform Tensor input",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WhQTZmQY6g4c"
      },
      "source": [
        "## NestedTensor as unifying datastructure for non-uniform Tensor input\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z6sn7kkU6jV1"
      },
      "source": [
        "See [the corresponding RFC for more background on motivation](https://docs.google.com/document/d/1VdKG5JA0U8iiwd6eYpUlCItm3zNJns8_ooJvaH_JWV8/edit#).\n",
        "\n",
        "In general this construct is meant as a container with the following layouts as inspired by the cited operators."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GB8aHyCn1xHc"
      },
      "source": [
        "from enum import Enum\n",
        "class Layout(Enum):\n",
        "    Masked = 0 # Example: TransformerEncoderLayer or CrossEntropyLoss by using the mask to fill with padding_idx\n",
        "    Packed = 1 # Example: EmbeddingBag\n",
        "    PackedSequence = 2 # Restricted to RNN\n",
        "    List = 3 # Fallback and default for quick creation"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oYT9BMcr1_Ag"
      },
      "source": [
        "The following hidden cell is an incomplete implementation of this using torch_function. This structure does layout conversions via a ```to``` method and provides a unified constructor, which accepts a list of Tensors and that allows the specification of a layout."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eK-_QTN36iIF",
        "cellView": "form"
      },
      "source": [
        "#@title\n",
        "import torch\n",
        "from enum import Enum\n",
        "\n",
        "def _nn_functional_embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,\n",
        "                                 scale_grad_by_freq=False, mode='mean', sparse=False,\n",
        "                                 per_sample_weights=None, include_last_offset=False):\n",
        "    # [...] Omitted input sanitization\n",
        "    # [...] Verify that nested_size is shape compliant, i.e. all 1d Tensors (sequences)\n",
        "    # Design decision: conversion happens automatically. This is similar to how we automatically\n",
        "    # make Tensor contiguous or convert from fp16 to fp32 or sparse to dense if needed.\n",
        "    # We could decide to throw a warning here.\n",
        "    input = input.to(Layout.Packed)\n",
        "    offsets = torch.tensor([0] + [x[0] for x in input.nested_size()[:-1]]).cumsum(0)\n",
        "    # We could consider caching this metadata in NestedTensor\n",
        "    offsets = offsets.to(data.device)\n",
        "    assert input.layout is Layout.Packed\n",
        "    return torch.nn.functional.embedding_bag(\n",
        "        input.data,\n",
        "        weight,\n",
        "        offsets,\n",
        "        max_norm,\n",
        "        norm_type,\n",
        "        scale_grad_by_freq,\n",
        "        mode,\n",
        "        sparse,\n",
        "        per_sample_weights,\n",
        "        include_last_offset)\n",
        "\n",
        "def nested_tensor(tensors, layout=Layout.List, dtype=None, device=None, requires_grad=False): # pin_memory could be added as a layout\n",
        "    \"\"\"\n",
        "    Given a list of Tensors, each of the same dimension but variable shape, construct a NestedTensor that represents\n",
        "    this list of Tensors.\n",
        "\n",
        "    If a given entry of tensors does not match the dtype or device of the others, the result dtype or device needs to\n",
        "    be specified explicitly\n",
        "    \"\"\"\n",
        "    assert layout is Layout.List # No other layout support for now\n",
        "    assert isinstance(tensors, list)\n",
        "    assert len(tensors) > 0\n",
        "    dtype = tensors[0].dtype if dtype is None else dtype\n",
        "    device = tensors[0].device if device is None else device\n",
        "    # Change dtype and device if necessary\n",
        "    tensors = [t.to(device, dtype) for t in tensors]\n",
        "    nested_size = tuple(x.size() for x in tensors)\n",
        "    return NestedTensor(tensors, nested_size, Layout.List, dtype, device, requires_grad).to(layout)\n",
        "\n",
        "def _from_packed_sequence_to_list(packed_sequence):\n",
        "    padded, lengths = torch.nn.utils.rnn.pad_packed_sequence(packed_sequence, batch_first=True)\n",
        "    tensors = []\n",
        "    for i, length in enumerate(lengths):\n",
        "        tensors.append(padded[i, :length])\n",
        "    return tensors\n",
        "\n",
        "def as_nested_tensor(data, layout=Layout.List, dtype=None, device=None, requires_grad=False): # pin_memory could be added as a layout\n",
        "    \"\"\"\n",
        "    Similar to torch.as_tensor, this converts the given data into a NestedTensor.\n",
        "    \"\"\"\n",
        "    if isinstance(data, torch.nn.utils.rnn.PackedSequence):\n",
        "        return nested_tensor(_from_packed_sequence_to_list(data))\n",
        "    raise NotImplementedError(\"as_nested_tensor cannot convert data of type {} into a NestedTensor.\".format(type(data)))\n",
        "\n",
        "\n",
        "def _from_list_to_layout(list_nt, target_layout):\n",
        "    assert list_nt.layout is Layout.List\n",
        "    if target_layout is Layout.List:\n",
        "        return list_nt\n",
        "    if target_layout is Layout.Masked:\n",
        "        max_size = [len(list_nt.data)]\n",
        "        for d in range(list_nt.data[0].dim()):\n",
        "            max_size.append(max(x.size(d) for x in list_nt.data))\n",
        "        # This approach doesn't support autograd and can also be used during construction or without autograd\n",
        "        # An approach that does work with autograd uses pad and cat, but is a bit more involved\n",
        "        # See https://github.com/pytorch/nestedtensor/blob/master/nestedtensor/nested/masking.py#L142 for a complete implementation\n",
        "        data = torch.zeros(*max_size, dtype=list_nt.dtype, device=list_nt.device)\n",
        "        mask = torch.zeros(*max_size, dtype=torch.bool, device=list_nt.device)\n",
        "        for d_t, d_m, t in zip(data, mask, list_nt.data):\n",
        "            for d in range(t.dim()):\n",
        "                d_t = d_t.narrow(d, 0, t.size(d))\n",
        "                d_m = d_m.narrow(d, 0, t.size(d))\n",
        "            d_t.copy_(t.detach())\n",
        "            d_m.fill_(1)\n",
        "        return NestedTensor(data, list_nt.nested_size(), Layout.Masked, list_nt.dtype, list_nt.device, list_nt.requires_grad, metadata=mask)\n",
        "    if target_layout is Layout.Packed:\n",
        "        offsets_ = list_nt.nested_size()\n",
        "        data = torch.cat([x.reshape(-1) for x in list_nt.data]) # shape information is stored in nested_size\n",
        "        return NestedTensor(data, list_nt.nested_size(), Layout.Packed, list_nt.dtype, list_nt.device, list_nt.requires_grad)\n",
        "    if target_layout is Layout.PackedSequence:\n",
        "        return NestedTensor(torch.nn.utils.rnn.pack_sequence(list_nt.data, enforce_sorted=False), # enforce_sorted set to False doesn't support ONNX for now,\n",
        "                            list_nt.nested_size(),\n",
        "                            Layout.PackedSequence,\n",
        "                            list_nt.dtype,\n",
        "                            list_nt.device,\n",
        "                            list_nt.requires_grad)\n",
        "    raise NotImplemented(\"Converstion from list to target layout {} not supported.\".format(target_layout.name))\n",
        "            \n",
        "class NestedTensor(object):\n",
        "    def __init__(self, data, nested_size, layout, dtype, device, requires_grad, metadata=None):\n",
        "        # Can be list of tensors, single packed or masked Tensor or PackedSequence\n",
        "        self.data = data\n",
        "        # Metadata is overloaded with type and meaning\n",
        "        # Masked: Stores bool mask where True means included, False means excluded\n",
        "        # Packed: Stores 1d Tensor of offsets. offsets are the length of each entry in the flat data. Packed currently only supports 2d NestedTensors\n",
        "        # PackedSequence: Stores the lengths of the PackedSequence\n",
        "        self.metadata = metadata\n",
        "        self._nested_size = nested_size\n",
        "        self._layout = layout\n",
        "        self._dtype = dtype\n",
        "        self._device = device\n",
        "        # Gradient is supported by differentiable layout conversion functions a tracked by data field\n",
        "        self._requires_grad = requires_grad \n",
        "\n",
        "    def __torch_function__(self, func, types, args=(), kwargs=None):\n",
        "        if func is torch.nn.functional.embedding_bag:\n",
        "            # Design decision pending: We could make conversion to Layout.Padding automatic\n",
        "            return _nn_functional_embedding_bag(*args, **kwargs)\n",
        "        raise NotImplementedError(\"Given func {} does not support NestedTensor.\".format(func))\n",
        "\n",
        "    def nested_size(self):\n",
        "        return self._nested_size\n",
        "\n",
        "    @property\n",
        "    def dtype(self):\n",
        "        return self._dtype\n",
        "\n",
        "    @property\n",
        "    def layout(self):\n",
        "        return self._layout\n",
        "\n",
        "    @property\n",
        "    def device(self):\n",
        "        return self._device\n",
        "\n",
        "    @property\n",
        "    def requires_grad(self):\n",
        "        return self._requires_grad\n",
        "\n",
        "    # There are 5 layouts, therefore there are 20 possible\n",
        "    # conversions excluding identities\n",
        "    def to(self, target_layout):\n",
        "        assert isinstance(target_layout, Layout)\n",
        "        if self.layout is target_layout:\n",
        "            return self\n",
        "        if self.layout is Layout.List:\n",
        "            return _from_list_to_layout(self, target_layout)\n",
        "        raise NotImplementedError(\n",
        "            \"Cannot convert {} to desired layout {}\".format(\n",
        "                self.layout.name, target_layout.name))\n",
        "\n",
        "    \n",
        "    def to_tensor_list(self):\n",
        "        # Returns a list of Tensors\n",
        "        return self.to(Layout.List).data\n",
        "\n",
        "    def to_padded(self, padding_value=-1):\n",
        "        # Returns a Tensor padded with padding_value\n",
        "        converted = self.to(Layout.Masked)\n",
        "        return converted.data.masked_fill_(~converted.metadata, padding_value)\n",
        "\n",
        "    def to_masked(self):\n",
        "        # Returns a Tensor plus a Bool mask of same shape\n",
        "        converted = self.to(Layout.Masked)\n",
        "        return converted.data, converted.mask\n",
        "\n",
        "    def to_packed_sequence(self):\n",
        "        return self.to(Layout.PackedSequence).data\n",
        "              "
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PtvOXIbaCgn0"
      },
      "source": [
        "Let's step through an intended usecase and compare it a current application.\n",
        "\n",
        "The following EmbeddingBag represents a lookupt table of 10 vectors, each of dimension 3."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MaXisU5zAOsI"
      },
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "embedding_bag = nn.EmbeddingBag(10, 3)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z8rk3SfMC1Xu"
      },
      "source": [
        "Let's construct a list of tensors filled with a varying degree of word ids and feed it into EmbeddingBag as we were to right now."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YXl9dk-lDoQS"
      },
      "source": [
        "sentences = [torch.tensor([0, 3, 1]), torch.tensor([5, 1, 2, 4]), torch.tensor([3, 2])]"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Cf6bagzvCyu8",
        "outputId": "b5656700-8af6-463c-e286-b5500d3f6626"
      },
      "source": [
        "data = torch.cat(sentences)\n",
        "offsets = torch.tensor([0] + [len(x) for x in sentences[:-1]]).cumsum(0)\n",
        "print(offsets)\n",
        "print(embedding_bag(data, offsets))"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([0, 3, 7])\n",
            "tensor([[-0.0482,  0.0242, -0.6505],\n",
            "        [-0.6074,  0.6866, -0.4335],\n",
            "        [ 0.5125, -0.1862, -0.8296]], grad_fn=<EmbeddingBagBackward>)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cefuO5twDi3a"
      },
      "source": [
        "And this is what it'll look like with NestedTensor"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Oznb-50zDXSY",
        "outputId": "22581039-15b6-4296-ca30-4c7a465b287c"
      },
      "source": [
        "nt = nested_tensor(sentences)\n",
        "print(nt.nested_size())\n",
        "embedding_bag(nt)"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "(torch.Size([3]), torch.Size([4]), torch.Size([2]))\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[-0.0482,  0.0242, -0.6505],\n",
              "        [-0.6074,  0.6866, -0.4335],\n",
              "        [ 0.5125, -0.1862, -0.8296]], grad_fn=<EmbeddingBagBackward>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lTg0ePVkcI_L"
      },
      "source": [
        "Is it going to be less efficient to first construct a NestedTensor and then convert into an operator specific data structure? If we do this automatically we have the chance of optimizing a conversion, but we also run the risk of converting prematurely or in an inefficient way. This is the usual lazy vs. eager tradeoff and the current PyTorch convention seem to lean towards automatic conversion (e.g. when given non-contiguous inputs, sparse inputs (usually) or inputs of other dtype)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0VHEwOBAgpQX",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "1d1021f1-52c7-4de0-f578-739886aec073"
      },
      "source": [
        "print(nt.to_padded())"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[ 0,  3,  1, -1],\n",
            "        [ 5,  1,  2,  4],\n",
            "        [ 3,  2, -1, -1]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Uv6gfUAriXd_",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "69454e61-91e2-4716-f01a-8a46fbe9255b"
      },
      "source": [
        "print(nt.to_tensor_list())"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[tensor([0, 3, 1]), tensor([5, 1, 2, 4]), tensor([3, 2])]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2faLBjWOxGbl",
        "outputId": "804f3be9-94af-4c9a-bf5f-baabc3cb072c"
      },
      "source": [
        "rnn = nn.RNN(5, #embedding dimension\n",
        "             3, 2)\n",
        "h0 = torch.randn(2, 3, 3)\n",
        "embeddings = [s.unsqueeze(1).repeat(1, 5) #emulating embedding\n",
        "              for s in sentences]\n",
        "nt = nested_tensor(embeddings, dtype=torch.float)\n",
        "\n",
        "try:\n",
        "  rnn(nt) # \n",
        "except AttributeError as e:\n",
        "  print(e)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "'NestedTensor' object has no attribute 'size'\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "17hEaYfqFEil"
      },
      "source": [
        "RNN doesn't have good torch_function support, but luckily we can just convert manually into the desired format."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3ud7ra0qE9L1",
        "outputId": "2d90c9fc-b741-4891-eae6-c83358ec0aa3"
      },
      "source": [
        "ps = nt.to_packed_sequence()\n",
        "output, hn = rnn(ps, h0)\n",
        "print(output)\n"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "PackedSequence(data=tensor([[ 0.1349, -0.1506,  0.8108],\n",
            "        [ 0.6356,  0.2794,  0.7581],\n",
            "        [-0.1012,  0.3027,  0.9623],\n",
            "        [ 0.3990, -0.0811,  0.6990],\n",
            "        [ 0.0292, -0.2913,  0.7972],\n",
            "        [ 0.3070, -0.4692,  0.7617],\n",
            "        [ 0.2164, -0.0570,  0.7273],\n",
            "        [ 0.4771,  0.0845,  0.6256],\n",
            "        [-0.0036, -0.2968,  0.7427]], grad_fn=<CatBackward>), batch_sizes=tensor([3, 3, 2, 1]), sorted_indices=tensor([1, 0, 2]), unsorted_indices=tensor([1, 0, 2]))\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mZhjO8n9b9HS"
      },
      "source": [
        "And now we use the as_nested_tensor function (similar to torch.as_tensor) to interpret the resulting value (which is also a PackedSequence) as a NestedTensor again. This is useful in particular when you're about to feed this output into a linear layer as your final projection before the loss, because you can retrieve the padded version of your output."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Hzq_EIkvb986",
        "outputId": "ed1715cf-bd78-4752-b4a8-9dad0410ee79"
      },
      "source": [
        "output_nt = as_nested_tensor(output)\n",
        "padded_output = output_nt.to_padded(0)\n",
        "print(padded_output.size())\n",
        "print(padded_output)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "torch.Size([3, 4, 3])\n",
            "tensor([[[ 0.6356,  0.2794,  0.7581],\n",
            "         [ 0.0292, -0.2913,  0.7972],\n",
            "         [ 0.4771,  0.0845,  0.6256],\n",
            "         [ 0.0000,  0.0000,  0.0000]],\n",
            "\n",
            "        [[ 0.1349, -0.1506,  0.8108],\n",
            "         [ 0.3990, -0.0811,  0.6990],\n",
            "         [ 0.2164, -0.0570,  0.7273],\n",
            "         [-0.0036, -0.2968,  0.7427]],\n",
            "\n",
            "        [[-0.1012,  0.3027,  0.9623],\n",
            "         [ 0.3070, -0.4692,  0.7617],\n",
            "         [ 0.0000,  0.0000,  0.0000],\n",
            "         [ 0.0000,  0.0000,  0.0000]]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "04IG1a8QdFy4"
      },
      "source": [
        "loss = nn.NLLLoss()"
      ],
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "w6if_rHweF1b",
        "outputId": "54a91253-6c99-45cf-9f6e-431eac595591"
      },
      "source": [
        "targets = torch.tensor([1, 2, 1, -100, 2, 1, 1, 2, 1, 1, -100, -100])\n",
        "loss(padded_output.reshape(-1, 3), targets)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor(-0.2678)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 13
        }
      ]
    }
  ]
}